{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate superpixel-based pseudolabels\n",
    "\n",
    "\n",
    "### Overview\n",
    "\n",
    "This is the third step for data preparation\n",
    "\n",
    "Input: normalized images\n",
    "\n",
    "Output: pseudolabel candidates for all the images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n"
     ]
    }
   ],
   "source": [
    "# -f skips the interactive confirmation prompt, so Restart & Run All is not blocked\n",
    "%reset -f\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "import skimage\n",
    "\n",
    "from skimage.segmentation import slic\n",
    "from skimage.segmentation import mark_boundaries\n",
    "from skimage.util import img_as_float\n",
    "from skimage.measure import label \n",
    "import scipy.ndimage.morphology as snm\n",
    "from skimage import io\n",
    "import argparse\n",
    "import numpy as np\n",
    "import glob\n",
    "\n",
    "import SimpleITK as sitk\n",
    "import os\n",
    "\n",
    "def to01(x):\n",
    "    \"\"\"Min-max rescale ``x`` so that its minimum maps to 0 and its maximum to 1.\n",
    "\n",
    "    Args:\n",
    "        x: array-like object exposing ``min()`` and ``max()`` (e.g. a numpy array).\n",
    "\n",
    "    Returns:\n",
    "        The rescaled array. A constant input has zero range; in that case an\n",
    "        all-zero float array is returned instead of dividing by zero.\n",
    "    \"\"\"\n",
    "    lo, hi = x.min(), x.max()\n",
    "    span = hi - lo\n",
    "    if span == 0:\n",
    "        # constant input: avoid the 0 / 0 division that would produce NaNs\n",
    "        return np.zeros_like(x, dtype=float)\n",
    "    return (x - lo) / span\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Summary**\n",
    "\n",
    "a. Generate a mask of the patient to avoid pseudolabels of empty regions in the background\n",
    "\n",
    "b. Generate superpixels as pseudolabels\n",
    "\n",
    "**Configurations of pseudolabels**\n",
    "\n",
    "```python\n",
    "# default setting of minimum superpixel sizes\n",
    "segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)\n",
    "# you can also try other configs\n",
    "segs = seg_func(img[ii, ...], min_size = 100, sigma = 0.8)\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset settings: glob pattern of the normalized input volumes,\n",
    "# directory that receives the outputs, and the intensity threshold that is\n",
    "# passed on as fg_thresh when building the foreground mask.\n",
    "DATASET_CONFIG = {\n",
    "    'SABS': {\n",
    "        'img_bname': './SABS/sabs_CT_normalized/image_*.nii.gz',\n",
    "        'out_dir': './SABS/sabs_CT_normalized',\n",
    "        'fg_thresh': 1e-4,\n",
    "    },\n",
    "    # NOTE(review): the input pattern starts with '../' while out_dir starts\n",
    "    # with './' -- confirm that reading and writing in different trees is intended.\n",
    "    'CHAOST2': {\n",
    "        'img_bname': '../CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz',\n",
    "        'out_dir': './CHAOST2/chaos_MR_T2_normalized',\n",
    "        'fg_thresh': 1e-4 + 50,\n",
    "    },\n",
    "}\n",
    "\n",
    "# key of DATASET_CONFIG selecting the dataset to process\n",
    "DOMAIN = 'SABS'\n",
    "img_bname, out_dir = (DATASET_CONFIG[DOMAIN][key] for key in ('img_bname', 'out_dir'))\n",
    "# all files matching the input pattern (order is fixed by a later sort)\n",
    "imgs = glob.glob(img_bname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['./SABS/sabs_CT_normalized/image_25.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_2.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_28.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_12.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_0.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_27.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_10.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_6.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_21.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_16.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_9.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_23.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_4.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_14.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_19.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_17.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_20.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_7.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_18.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_15.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_5.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_22.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_8.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_13.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_29.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_3.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_24.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_11.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_26.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_1.nii.gz']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# order the volumes by the integer id embedded in 'image_<id>.nii.gz'\n",
    "imgs = sorted(imgs, key = lambda fid: int(fid.rsplit('_', 1)[-1].partition('.nii.gz')[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['./SABS/sabs_CT_normalized/image_0.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_1.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_2.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_3.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_4.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_5.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_6.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_7.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_8.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_9.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_10.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_11.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_12.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_13.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_14.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_15.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_16.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_17.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_18.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_19.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_20.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_21.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_22.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_23.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_24.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_25.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_26.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_27.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_28.nii.gz',\n",
       " './SABS/sabs_CT_normalized/image_29.nii.gz']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODE = 'MIDDLE' # minimum size of pseudolabels. 'MIDDLE' is the default setting\n",
    "\n",
    "# wrapper for processing a 3d image slice by slice in 2d\n",
    "def superpix_vol(img, method = 'fezlen', **kwargs):\n",
    "    \"\"\"\n",
    "    Run a 2d superpixel segmentation on every slice of a volume.\n",
    "\n",
    "    Args:\n",
    "        img: volume indexed as img[z, ...]; every img[z] is segmented on its own.\n",
    "        method: segmentation algorithm. Only Felzenszwalb is implemented,\n",
    "            selected by 'fezlen' (historical spelling) or 'felzenszwalb'.\n",
    "        **kwargs: forwarded to the segmentation function; they override the\n",
    "            defaults of the current MODE (e.g. min_size = 100, sigma = 0.8).\n",
    "\n",
    "    Returns:\n",
    "        Float array shaped like img holding the label map of each slice.\n",
    "        Every slice is labelled independently of the others.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: for an unknown method or an unsupported MODE.\n",
    "    \"\"\"\n",
    "    if method in ('fezlen', 'felzenszwalb'):\n",
    "        seg_func = skimage.segmentation.felzenszwalb\n",
    "    else:\n",
    "        raise NotImplementedError('unknown superpixel method: {}'.format(method))\n",
    "\n",
    "    # defaults tied to MODE; resolved once instead of once per slice\n",
    "    if MODE == 'MIDDLE':\n",
    "        seg_params = {'min_size': 400, 'sigma': 1}\n",
    "    else:\n",
    "        raise NotImplementedError('unsupported MODE: {}'.format(MODE))\n",
    "    # explicit keyword arguments take precedence over the MODE defaults\n",
    "    seg_params.update(kwargs)\n",
    "\n",
    "    out_vol = np.zeros(img.shape)\n",
    "    for ii in range(img.shape[0]):\n",
    "        out_vol[ii, ...] = seg_func(img[ii, ...], **seg_params)\n",
    "\n",
    "    return out_vol\n",
    "\n",
    "# binary mask of the patient obtained by thresholding the intensities\n",
    "def fg_mask2d(img_2d, thresh): # change this by your need\n",
    "    \"\"\"\n",
    "    Build the foreground mask of a single slice.\n",
    "\n",
    "    Pixels with an intensity above thresh count as foreground. If there are\n",
    "    none, the all-zero float32 threshold map is returned as is. Otherwise only\n",
    "    the largest connected component is kept and its holes are filled.\n",
    "    \"\"\"\n",
    "    above = np.float32(img_2d > thresh)\n",
    "    if above.max() < 0.999:\n",
    "        # nothing exceeds the threshold\n",
    "        return above\n",
    "\n",
    "    components = label(above)\n",
    "    assert components.max() != 0 # assume at least 1 CC\n",
    "    # index 0 of the histogram is the background, so it is left out of the\n",
    "    # argmax and the winning index is shifted back by one\n",
    "    sizes = np.bincount(components.ravel())\n",
    "    biggest = 1 + np.argmax(sizes[1:])\n",
    "    return snm.binary_fill_holes(components == biggest)\n",
    "\n",
    "# remove superpixels within the empty regions\n",
    "def superpix_masking(raw_seg2d, mask2d):\n",
    "    \"\"\"\n",
    "    Keep only the superpixels inside the foreground mask and relabel them from 1.\n",
    "\n",
    "    Args:\n",
    "        raw_seg2d: 2d label map of one slice; labels may include 0. It is cast\n",
    "            to int32 and left unmodified.\n",
    "        mask2d: foreground mask of the same shape (0 = discard).\n",
    "\n",
    "    Returns:\n",
    "        Float array of the same shape. 0 marks discarded pixels. Every input\n",
    "        label, including label 0, gets its own positive id: ids follow the\n",
    "        sorted input labels, with input label 0 placed last. An id is consumed\n",
    "        even for a label that is masked out completely, so ids can have gaps.\n",
    "    \"\"\"\n",
    "    seg = np.array(raw_seg2d, dtype = np.int32) # private copy, edited in place below\n",
    "    lbvs = np.unique(seg)\n",
    "    # 0 means background once the mask is applied, so the superpixel that\n",
    "    # currently carries label 0 is moved to a fresh id first\n",
    "    moved_zero_lb = lbvs.max() + 1\n",
    "    seg[seg == 0] = moved_zero_lb\n",
    "    # the fresh id must be part of the remapping as well, otherwise the\n",
    "    # superpixel originally labelled 0 would silently vanish from the output\n",
    "    lbvs = list(lbvs) + [moved_zero_lb]\n",
    "    seg = seg * mask2d\n",
    "    out_seg2d = np.zeros(seg.shape)\n",
    "    lb_new = 1\n",
    "    for lbv in lbvs:\n",
    "        if lbv == 0:\n",
    "            continue\n",
    "        out_seg2d[seg == lbv] = lb_new\n",
    "        lb_new += 1\n",
    "\n",
    "    return out_seg2d\n",
    "            \n",
    "def superpix_wrapper(img, verbose = False, fg_thresh = 1e-4):\n",
    "    \"\"\"\n",
    "    Compute the foreground mask and the masked superpixels of a whole volume.\n",
    "\n",
    "    The volume is first segmented with superpix_vol. Then, slice by slice, a\n",
    "    foreground mask is built with fg_mask2d (using fg_thresh) and applied to\n",
    "    the raw superpixels through superpix_masking. With verbose set, the index\n",
    "    of each slice is printed before it is processed.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (fg_mask_vol, processed_seg_vol); both arrays have the shape of\n",
    "        the raw segmentation.\n",
    "    \"\"\"\n",
    "    raw_seg = superpix_vol(img)\n",
    "    vol_shape = raw_seg.shape\n",
    "    fg_mask_vol, processed_seg_vol = np.zeros(vol_shape), np.zeros(vol_shape)\n",
    "    for ii, seg2d in enumerate(raw_seg):\n",
    "        if verbose:\n",
    "            print(\"doing {} slice\".format(ii))\n",
    "        mask2d = fg_mask2d(img[ii, ...], fg_thresh)\n",
    "        fg_mask_vol[ii] = mask2d\n",
    "        processed_seg_vol[ii] = superpix_masking(seg2d, mask2d)\n",
    "    return fg_mask_vol, processed_seg_vol\n",
    "        \n",
    "# transfer spacing, origin and direction from one sitk object to another\n",
    "def copy_info(src, dst):\n",
    "    \"\"\"\n",
    "    Copy spacing, origin and direction from src onto dst.\n",
    "\n",
    "    dst is updated through its Set* methods and is also returned.\n",
    "    \"\"\"\n",
    "    for field in ('Spacing', 'Origin', 'Direction'):\n",
    "        value = getattr(src, 'Get' + field)()\n",
    "        getattr(dst, 'Set' + field)(value)\n",
    "    return dst\n",
    "\n",
    "\n",
    "def strip_(img, lb):\n",
    "    \"\"\"\n",
    "    Keep only the requested label value(s) of a label map.\n",
    "\n",
    "    Args:\n",
    "        img: label map; it is cast to int32 before the comparison.\n",
    "        lb: a single label (int or float, truncated to int) or a\n",
    "            list/tuple/array of labels.\n",
    "\n",
    "    Returns:\n",
    "        Array shaped like img in which pixels equal to a requested label keep\n",
    "        that label value and every other pixel is 0. The result is float32 for\n",
    "        a single label and float64 for a collection of labels.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if lb is neither a number nor a list/tuple/array.\n",
    "    \"\"\"\n",
    "    img = np.int32(img)\n",
    "    # plain ints and numpy scalars are accepted in addition to Python floats\n",
    "    if isinstance(lb, (int, float, np.integer, np.floating)):\n",
    "        lb = int(lb)\n",
    "        return np.float32(img == lb) * float(lb)\n",
    "    if isinstance(lb, (list, tuple, np.ndarray)):\n",
    "        out = np.zeros(img.shape)\n",
    "        for _lb in lb:\n",
    "            out += np.float32(img == int(_lb)) * float(_lb)\n",
    "\n",
    "        return out\n",
    "    raise TypeError('lb must be a number or a list/tuple/array of numbers, got {}'.format(type(lb).__name__))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image with id 0 has finished\n",
      "image with id 1 has finished\n",
      "image with id 2 has finished\n",
      "image with id 3 has finished\n",
      "image with id 4 has finished\n",
      "image with id 5 has finished\n",
      "image with id 6 has finished\n",
      "image with id 7 has finished\n",
      "image with id 8 has finished\n",
      "image with id 9 has finished\n",
      "image with id 10 has finished\n",
      "image with id 11 has finished\n",
      "image with id 12 has finished\n",
      "image with id 13 has finished\n",
      "image with id 14 has finished\n",
      "image with id 15 has finished\n",
      "image with id 16 has finished\n",
      "image with id 17 has finished\n",
      "image with id 18 has finished\n",
      "image with id 19 has finished\n",
      "image with id 20 has finished\n",
      "image with id 21 has finished\n",
      "image with id 22 has finished\n",
      "image with id 23 has finished\n",
      "image with id 24 has finished\n",
      "image with id 25 has finished\n",
      "image with id 26 has finished\n",
      "image with id 27 has finished\n",
      "image with id 28 has finished\n",
      "image with id 29 has finished\n"
     ]
    }
   ],
   "source": [
    "# Generate pseudolabels for every image and save them\n",
    "\n",
    "# the threshold depends only on the selected dataset, so look it up once\n",
    "fg_thresh = DATASET_CONFIG[DOMAIN]['fg_thresh']\n",
    "# make sure the output folder exists before anything is written\n",
    "os.makedirs(out_dir, exist_ok = True)\n",
    "\n",
    "for img_fid in imgs:\n",
    "    # id taken from the file name 'image_<id>.nii.gz'\n",
    "    idx = os.path.basename(img_fid).split('_')[-1].split('.nii.gz')[0]\n",
    "    im_obj = sitk.ReadImage(img_fid)\n",
    "\n",
    "    out_fg, out_seg = superpix_wrapper(sitk.GetArrayFromImage(im_obj), fg_thresh = fg_thresh)\n",
    "\n",
    "    # both outputs take over spacing, origin and direction of the source image\n",
    "    out_fg_o = copy_info(im_obj, sitk.GetImageFromArray(out_fg))\n",
    "    out_seg_o = copy_info(im_obj, sitk.GetImageFromArray(out_seg))\n",
    "\n",
    "    seg_fid = os.path.join(out_dir, f'superpix-{MODE}_{idx}.nii.gz')\n",
    "    msk_fid = os.path.join(out_dir, f'fgmask_{idx}.nii.gz')\n",
    "    sitk.WriteImage(out_fg_o, msk_fid)\n",
    "    sitk.WriteImage(out_seg_o, seg_fid)\n",
    "    print(f'image with id {idx} has finished')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
